
import numpy as np
from .FFTBase import FFTBase
from pynfft import nfsft
from lie_learn.spaces.spherical_quadrature import estimate_spherical_quadrature_weights
from lie_learn.representations.SO3.irrep_bases import change_of_basis_matrix, change_of_basis_function

class S2FFT_NFFT(FFTBase):
    """
    Spherical-harmonic transform on an arbitrary S2 sampling grid, backed by pynfft's NFSFT.

    Coefficients are expressed in the quantum-normalized, centered, real spherical-harmonic
    basis with Condon-Shortley phase, as a flat vector of length (L_max + 1) ** 2.
    """

    def __init__(self, L_max, x, w=None):
        """
        Set up the NFSFT plan, the change-of-basis maps and the quadrature weights.

        :param L_max: maximum spherical harmonic degree
        :param x: coordinates on spherical / spatial grid: either an (M, 2) array of grid
                  points, or a list of two equally-shaped coordinate arrays (as generated by
                  S2.meshgrid), which is flattened (row-major) into an (M, 2) array
        :param w: quadrature weights for the grid x, with one weight per grid point (any shape;
                  stored flattened). If None, weights are estimated with
                  estimate_spherical_quadrature_weights.
        """
        # If x is a list (generated by S2.meshgrid), convert to an (M, 2) array
        if isinstance(x, list):
            x = np.column_stack((x[0].ravel(), x[1].ravel()))

        # The NFSFT class can synthesize / analyze functions in terms of
        # NFFT-normalized, centered, complex spherical harmonics without Condon-Shortley phase.
        self._nfsft = nfsft.NFSFT(N=L_max, x=x)

        degrees = np.arange(L_max + 1)

        # c2r: change of basis from the NFFT spherical harmonics to our preferred choice, the
        # quantum-normalized, centered, real spherical harmonics with Condon-Shortley phase.
        self._c2r_func = change_of_basis_function(degrees,
                                                  frm=('complex', 'nfft', 'centered', 'nocs'),
                                                  to=('real', 'quantum', 'centered', 'cs'))

        # synthesize() needs the map vec -> c2r.T.dot(vec) as a function. The map built below
        # happens to equal c2r.conj().T, and c2r.T.dot(vec) == (c2r.conj().T.dot(vec.conj())).conj(),
        # so conjugating around it yields c2r.T.
        c2r_conj_T = change_of_basis_function(degrees,
                                              frm=('real', 'nfft', 'centered', 'cs'),
                                              to=('complex', 'quantum', 'centered', 'nocs'))
        self._c2r_T = lambda vec: c2r_conj_T(vec.conj()).conj()

        if w is None:
            # Estimate quadrature weights for the sampling set x.
            w = estimate_spherical_quadrature_weights(
                sampling_set=x, max_bandwidth=L_max,
                normalization='quantum', condon_shortley=True)[0]

        # Keep the weights as a flat copy so they line up element-wise with the
        # flattened samples used in analyze().
        self.w = np.asarray(w).flatten()

        self.x = x
        self.L_max = L_max

    def analyze(self, f):
        """
        Compute the real spherical-harmonic coefficients of f sampled on the grid x.

        :param f: function values at the grid points; either a flat array in grid-point order,
                  or an array laid out like the grid (e.g. shaped like the S2.meshgrid arrays),
                  which is flattened row-major just like x in __init__
        :return: real array of coefficients in the quantum-normalized, centered, real basis
                 with Condon-Shortley phase
        """
        # Flatten f so grid-shaped input matches the flat weight vector self.w instead of
        # failing to broadcast (or silently broadcasting to an (M, M) array).
        f = np.asarray(f).ravel()

        # We want the *weighted* adjoint transform, so that we get the exact Fourier coefficients
        # (at least for a proper sampling grid such as Clenshaw-Curtis or Gauss-Legendre and the
        # respective weights). Hence, the function to be transformed is f * w, and
        # _nfsft.adjoint() computes, w.r.t. the NFFT spherical harmonics Y,
        #   a_lm = sum_i Y_lm(theta_i, phi_i).conj() * w_i * f(theta_i, phi_i)
        self._nfsft.f = f * self.w
        self._nfsft.adjoint()
        a = self._nfsft.get_f_hat_flat()

        # Change to the real basis R. With R.T = c2r.dot(Y.T), the real coefficients are
        #   R.T.dot(w * f) = c2r.dot(Y.T.dot(w * f)) = c2r.dot(a.conj())   (for real f and w).
        # .real discards any residual imaginary part.
        return self._c2r_func(a.conj()).real

    def synthesize(self, f_hat):
        """
        Evaluate the function with real spherical-harmonic coefficients f_hat on the grid x.

        :param f_hat: array-like of (L_max + 1) ** 2 coefficients in the quantum-normalized,
                      centered, real basis with Condon-Shortley phase
        :return: real part of the function values at the grid points
        """
        # self._nfsft.trafo() computes the synthesis / forward transform using NFFT complex SH:
        # f = Y.dot(c), where Y is the M by (L_max+1)^2 matrix of complex NFFT spherical harmonics.
        # We have R.T = c2r.dot(Y.T), so f = R.dot(f_hat) = Y.dot(c2r.T.dot(f_hat)).
        complex_coefs = self._c2r_T(np.asarray(f_hat))
        self._nfsft.set_f_hat_flat(complex_coefs)
        f = self._nfsft.trafo(use_dft=False, return_copy=True)
        return f.real